import os
import argparse
import pandas as pd
import numpy as np

from sklearn.ensemble import (
    GradientBoostingRegressor,
    RandomForestRegressor,
    RandomForestClassifier,
)
from doubleml import DoubleMLPLR, DoubleMLIRM
from doubleml import DoubleMLData
from sklearn.linear_model import LassoCV, LogisticRegressionCV
import time
from helper import model_saver, load_config
from data_utils import impute_missing_data, add_mean_encoding, prepare_dml_data
from typing import Union

import warnings

# Hide every FutureWarning for the whole process — presumably deprecation
# noise from the data/ML libraries. NOTE(review): confirm nothing important
# is being masked.
warnings.filterwarnings(action="ignore", category=FutureWarning)

# Seed NumPy's legacy global RNG so that anything drawing from it (e.g.
# estimators left with random_state=None) behaves reproducibly across runs.
np.random.seed(seed=123)

# Columns that may contain missing values and are passed to
# impute_missing_data. Each base attribute is listed for both the `_i` and
# the `_j` side of a pair, `_i` first.
MISSING_COLS = [
    f"{base}_{side}"
    for base in (
        "assignee_id",
        "main_cpc_subclass",
        "main_cpc_section",
        "assignee_country",
    )
    for side in ("i", "j")
]

# Lookup table from the class names written in config files to the classes
# themselves. It holds both the sklearn nuisance learners and the DoubleML
# estimator classes; instantiate_model resolves config strings through it.
MODEL_MAPPING = dict(
    # sklearn learners
    RandomForestClassifier=RandomForestClassifier,
    GradientBoostingRegressor=GradientBoostingRegressor,
    RandomForestRegressor=RandomForestRegressor,
    LassoCV=LassoCV,
    LogisticRegressionCV=LogisticRegressionCV,
    # DoubleML estimators
    DoubleMLPLR=DoubleMLPLR,
    DoubleMLIRM=DoubleMLIRM,
)


def train_doubleml_model(
    dml_data: DoubleMLData,
    doubleml_model,
    outcome_model,
    prediction_model,
    *,
    n_jobs_cv: int = -1,
):
    """Instantiate and fit a DoubleML estimator, printing timing and summary.

    Args:
        dml_data: the prepared DoubleMLData backend.
        doubleml_model: the DoubleML estimator *class* (e.g. DoubleMLPLR or
            DoubleMLIRM); it is called as
            ``doubleml_model(dml_data, outcome_model, prediction_model)``.
        outcome_model: learner passed as the estimator's first nuisance model.
        prediction_model: learner passed as the estimator's second nuisance
            model.
        n_jobs_cv: forwarded to ``fit``. The default -1 matches the value
            previously hard-coded here (joblib convention: use all CPUs).

    Returns:
        The fitted DoubleML model.
    """
    dml = doubleml_model(dml_data, outcome_model, prediction_model)
    print("start fitting model...")
    # perf_counter is monotonic, so the elapsed time cannot go negative or
    # jump if the wall clock is adjusted mid-fit (unlike time.time()).
    start = time.perf_counter()
    dml.fit(n_jobs_cv=n_jobs_cv)
    print("model fitting done...")
    print("time(s) taken fitting model: ", time.perf_counter() - start)
    print(dml.summary)
    return dml


def compute_ate(doubleml_model: Union[DoubleMLIRM, DoubleMLPLR]) -> pd.DataFrame:
    """Return the average-treatment-effect table of a fitted DoubleML model.

    This is the model's ``summary`` DataFrame, returned as-is.
    """
    ate_table = doubleml_model.summary
    return ate_table


def compute_gate(
    doubleml_model: DoubleMLIRM,
    groups: pd.DataFrame,
    pointwise_ci: bool = True,
    *,
    level: float = 0.95,
    n_rep_boot: int = 1000,
) -> pd.DataFrame:
    """Compute group average treatment effects (GATEs) for a fitted model.

    Args:
        doubleml_model: a fitted DoubleMLIRM model.
        groups: one column per group, passed unchanged to
            ``doubleml_model.gate``. NOTE(review): rows are assumed to align
            with the observations the model was fitted on — confirm at caller.
        pointwise_ci: if True, return ``gate.summary`` (per-group estimates
            with pointwise intervals). If False, return joint (simultaneous)
            confidence intervals from ``gate.confint(joint=True, ...)``.
        level: confidence level for the joint intervals. It is only used when
            ``pointwise_ci`` is False.
        n_rep_boot: bootstrap repetitions for the joint intervals. It is only
            used when ``pointwise_ci`` is False.

    Returns:
        A DataFrame with one row per group.

    Raises:
        ValueError: if ``doubleml_model`` is a DoubleMLPLR.
    """
    if isinstance(doubleml_model, DoubleMLPLR):
        raise ValueError("model type should be DoubleMLIRM")
    gate = doubleml_model.gate(groups=groups)

    if pointwise_ci:
        return gate.summary
    return gate.confint(level=level, joint=True, n_rep_boot=n_rep_boot)


def _build_learner(learner_config: dict, role: str):
    """Instantiate the learner described by one ``*_model`` config section."""
    algorithm = learner_config["algorithm"]
    if algorithm not in MODEL_MAPPING:
        raise ValueError(f"Invalid {role} algorithm: {algorithm}")
    # A missing or null "hyperparameters" entry means "use library defaults".
    hyperparameters = learner_config.get("hyperparameters") or {}
    return MODEL_MAPPING[algorithm](**hyperparameters)


def instantiate_model(config: dict) -> tuple:
    """Resolve the DoubleML class and build both nuisance learners from config.

    ``config["model"]`` must contain:
      * ``doubleml_model``: "DoubleMLPLR" or "DoubleMLIRM";
      * ``outcome_model`` and ``prediction_model``: each has an ``algorithm``
        key naming a class in MODEL_MAPPING and an optional
        ``hyperparameters`` dict of constructor keyword arguments.

    Returns:
        ``(doubleml_model_class, outcome_model, prediction_model)``. The
        DoubleML class is returned uninstantiated; the two learners are
        constructed instances.

    Raises:
        ValueError: if ``doubleml_model`` is not a DoubleML estimator name, or
            if a learner's ``algorithm`` is not in MODEL_MAPPING.
    """
    model_config = config["model"]

    doubleml_model_type = model_config["doubleml_model"]
    doubleml_model = MODEL_MAPPING.get(doubleml_model_type)
    # MODEL_MAPPING also holds sklearn learners, so plain membership would let
    # e.g. "LassoCV" through as the DoubleML model. Require an actual
    # DoubleML estimator class instead.
    if doubleml_model not in (DoubleMLPLR, DoubleMLIRM):
        raise ValueError(f"Invalid model_type: {doubleml_model_type}")

    outcome_model = _build_learner(model_config["outcome_model"], "outcome_model")
    prediction_model = _build_learner(
        model_config["prediction_model"], "prediction_model"
    )

    print("doubleml_model: ", doubleml_model)
    print("outcome_model: ", outcome_model)
    print("prediction_model: ", prediction_model)

    return (doubleml_model, outcome_model, prediction_model)


def parse_arguments(argv=None):
    """Parse the command-line options for the DML script.

    Usage:
        python3 src/doubleml_gender.py --config config/model_2_config.json

    Args:
        argv: argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        An ``argparse.Namespace`` whose ``config`` attribute is the config
        file path.
    """
    parser = argparse.ArgumentParser(description="Main script for DML")
    # required=True: otherwise omitting the flag silently yields config=None,
    # which would then be handed straight to load_config.
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the config.json file"
    )
    return parser.parse_args(argv)


def main():
    """Run the DoubleML pipeline described by the ``--config`` JSON file.

    Steps:
      1. Load the parquet dataset and drop rows missing a CPC section.
      2. Impute ``MISSING_COLS`` and optionally mean-encode the assignee IDs.
      3. Fit the configured DoubleML model.
      4. Write ``ate.csv`` to ``<output.model_save_path>/<model_name>/``, plus
         ``gate.csv`` when the model is ``DoubleMLIRM``.
    """
    args = parse_arguments()
    config = load_config(args.config)
    data_config = config["data"]

    # Copy so that dropping the mean-encoded IDs below does not mutate the
    # loaded config's list in place.
    cate_vars = list(data_config["cate_vars"])
    # Deliberately NOT copied: this same list object goes to add_mean_encoding
    # and then to prepare_dml_data, so any in-place update add_mean_encoding
    # makes to it is preserved (whether it does so is not visible here).
    cont_vars = data_config["cont_vars"]

    df = pd.read_parquet(data_config["data_path"])
    # df = pd.read_csv(data_config["data_path"], low_memory=False)
    df = df.dropna(subset=["main_cpc_section_i", "main_cpc_section_j"])

    # The return value is ignored, so impute_missing_data is assumed to fill
    # df in place — TODO confirm against data_utils.
    impute_missing_data(df, MISSING_COLS)

    # Optionally add mean encodings of the assignee IDs. The raw IDs are then
    # removed from the categorical covariates.
    encoded_ids = ("assignee_id_j", "assignee_id_i")
    if data_config["mean_encoding"] and all(
        col in cate_vars for col in encoded_ids
    ):
        df = add_mean_encoding(
            df,
            cont_vars=cont_vars,
            cols_to_encode=list(encoded_ids),
        )
        for col in encoded_ids:
            cate_vars.remove(col)

    dml_data = prepare_dml_data(df, cate_vars, cont_vars)
    print("dml data prepared.")

    dml = train_doubleml_model(dml_data, *instantiate_model(config))
    ate = compute_ate(dml)

    result_save_path = os.path.join(
        config["output"]["model_save_path"], config["model_name"]
    )
    # makedirs creates missing parent directories and tolerates an existing
    # directory, unlike the previous exists()/mkdir() pair.
    os.makedirs(result_save_path, exist_ok=True)

    # Keep the index. The DoubleML summary tables label their rows by
    # treatment (ATE) or group (GATE), and index=False would drop those labels.
    ate.to_csv(os.path.join(result_save_path, "ate.csv"))

    # Heterogeneous effects by gender of the `_i` (citing) side.
    if config["model"]["doubleml_model"] == "DoubleMLIRM":
        is_female = df["allfemale_09_100_i"]
        # NOTE(review): groups rows are positional (column_stack drops df's
        # index). This assumes prepare_dml_data keeps every row of df in the
        # same order — confirm.
        groups = pd.DataFrame(
            np.column_stack(
                [
                    is_female.eq(True),  # citing female
                    is_female.eq(False),  # citing male
                ]
            ),
            columns=["female_i", "male_i"],
        )
        gate = compute_gate(dml, groups)
        gate.to_csv(os.path.join(result_save_path, "gate.csv"))

    # # save fitted model
    # model_saver(dml, os.path.join(result_save_path, config["model_name"] + ".pkl"))
    # print("model saved!")


if __name__ == "__main__":
    main()
